/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "OperationResolver.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_rnn {

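// UNIDIRECTIONAL_SEQUENCE_RNN unrolls a basic RNN cell over the time dimension of the input.
// The input tensor is [maxTime, batchSize, inputSize] when the time-major parameter is set and
// [batchSize, maxTime, inputSize] otherwise; the output uses the same layout with inputSize
// replaced by numUnits.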
constexpr uint32_t kNumInputs = 7;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kRecurrentWeightsTensor = 2;
constexpr uint32_t kBiasTensor = 3;
constexpr uint32_t kHiddenStateTensor = 4;
constexpr uint32_t kActivationParam = 5;
constexpr uint32_t kTimeMajorParam = 6;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

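// Swaps the first two dimensions of a rank-3 tensor, e.g. converts a [batch, time, inputSize]
// buffer into [time, batch, inputSize] layout and vice versa. Used to move between batch-major
// and time-major formats.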
template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
    const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    for (uint32_t f = 0; f < firstDimSize; ++f) {
        for (uint32_t s = 0; s < secondDimSize; ++s) {
            for (uint32_t i = 0; i < inputSize; ++i) {
                const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
                const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
                output[outputIndex] = input[inputIndex];
            }
        }
    }
}

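// Runs the sequence RNN on buffers of element type T (float or _Float16). The input is brought
// into time-major layout if necessary, RNN::RNNStep is applied once per time step, and the
// result is transposed back to the caller's layout.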
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);
    const T* weights = context->getInputBuffer<T>(kWeightsTensor);
    Shape weightsShape = context->getInputShape(kWeightsTensor);
    const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
    Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
    const T* bias = context->getInputBuffer<T>(kBiasTensor);
    const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
    int32_t activation = context->getInputValue<int32_t>(kActivationParam);

    T* output = context->getOutputBuffer<T>(kOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> outputTransposed;
    if (!timeMajor) {
        // Convert input and output to time major format.
        inputTransposed.resize(getNumberOfElements(inputShape));
        outputTransposed.resize(getNumberOfElements(outputShape));
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        input = inputTransposed.data();
        output = outputTransposed.data();
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);

    // Shape of the input at a single time step, i.e. the input shape with the leading (time)
    // dimension removed.
    Shape fixedTimeInputShape = inputShape;
    fixedTimeInputShape.dimensions.resize(2);
    fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
    fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];

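    // Unroll the RNN over time: each step consumes one [batchSize, inputSize] slice of the input
    // and writes one [batchSize, numUnits] slice of the output; the output of step t becomes the
    // hidden state for step t + 1.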
    for (uint32_t i = 0; i < maxTime; ++i) {
        RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
                        recurrentWeights, recurrentWeightsShape, activation, output);
        input += batchSize * inputSize;
        hiddenState = output;
        output += batchSize * numUnits;
    }

    if (!timeMajor) {
        transposeFirstTwoDims(outputTransposed.data(), outputShape,
                              context->getOutputBuffer<T>(kOutputTensor));
    }
    return true;
}

}  // namespace

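// Checks operand counts and types: the five tensor operands must share the same floating point
// type, the activation and time-major parameters must be INT32 scalars, and the operation
// requires HAL version 1.2.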
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
        LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
                   << toString(inputType);
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
                                              OperandType::INT32, OperandType::INT32}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}

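// Checks that the operand ranks and dimensions are consistent with each other and computes the
// output shape: [maxTime, batchSize, numUnits] in time-major format, [batchSize, maxTime,
// numUnits] otherwise.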
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape hiddenState = context->getInputShape(kHiddenStateTensor);

    int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
    NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t numUnits = getSizeOfDimension(weights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);

    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
    NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));

    Shape output = context->getOutputShape(kOutputTensor);
    output.dimensions.resize(3);
    output.dimensions[0] = timeMajor ? maxTime : batchSize;
    output.dimensions[1] = timeMajor ? batchSize : maxTime;
    output.dimensions[2] = numUnits;

    return context->setOutputShape(kOutputTensor, output);
}

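// Dispatches to the typed implementation based on the element type of the input tensor.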
bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        return executeTyped<_Float16>(context);
    } else {
        return executeTyped<float>(context);
    }
}

}  // namespace unidirectional_sequence_rnn

NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);

}  // namespace nn
}  // namespace android