/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
#define FRAMEWORKS_ML_NN_LSTMCELL_H

#include "ActivationFunctor.h"
#include "HalOperation.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

#include <algorithm>
#include <cmath>

namespace android {
namespace nn {

// Parameters and feature flags passed to the LSTM evaluation routines.
//
// Every field has a default member initializer so that a default-constructed
// LSTMParams never carries indeterminate values (reading or copying an
// uninitialized bool/float is undefined behaviour). The struct remains an
// aggregate in C++14 and later, and the field order is unchanged, so existing
// aggregate/designated initialization keeps working.
struct LSTMParams {
    // Fused activation selector; value-initialized by default.
    // NOTE(review): presumably the zero value means "no activation" — confirm
    // against the TfLiteFusedActivation definition.
    TfLiteFusedActivation activation{};
    // Clipping thresholds for the cell state and the projection output.
    // NOTE(review): the NNAPI LSTM documentation describes 0 as "clipping
    // disabled" — confirm the implementation treats the default that way.
    float cell_clip = 0.0f;
    float proj_clip = 0.0f;
    // Feature flags; all default to false.
    // NOTE(review): names suggest these are derived from which optional
    // operands are present — confirm where they are assigned.
    bool use_cifg = false;
    bool use_peephole = false;
    bool use_layer_norm = false;
    bool use_projection_weight = false;
    bool use_projection_bias = false;
    bool merge_outputs = false;
    bool time_major = false;
};

// Forward declarations: this header only refers to these types through
// pointers and references (including a reference to a std::vector of
// RunTimeOperandInfo), so their full definitions are not needed here.
struct RunTimeOperandInfo;
struct Shape;

// Declares the LSTM operation: the operand index constants, the
// Prepare()/Eval() entry points, the static evaluation helpers that work on
// raw buffers, and the per-operation operand pointers.
//
// All data members have in-class initializers (nullptr / value-initialization)
// so that a member the constructor does not assign is a deterministic null or
// zero rather than an indeterminate value. This class declares no destructor,
// so the raw RunTimeOperandInfo pointers it stores are not released by it.
class LSTMCell {
   public:
    // Constructs the cell from an operation and the operand table.
    // NOTE(review): presumably resolves the operand pointers stored below from
    // `operands` — confirm in the implementation file.
    LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands);

    // Takes the operation, the operand table and four Shape out-pointers.
    // NOTE(review): presumably validates the operands and fills in the four
    // shapes, returning true on success — confirm in the implementation file.
    bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
                 Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                 Shape* outputShape);
    // Runs the operation; returns a bool status.
    // NOTE(review): presumably true means success — confirm in the implementation file.
    bool Eval();

    // Operand index constants.
    // NOTE(review): presumably positions in the operation's input list —
    // confirm in the implementation file.

    // Input Tensors of size {n_batch, n_input}
    static constexpr int kInputTensor = 0;

    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;  // Optional
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kCellToInputWeightsTensor = 9;    // Optional
    static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
    static constexpr int kCellToOutputWeightsTensor = 11;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 12;  // Optional
    static constexpr int kForgetGateBiasTensor = 13;
    static constexpr int kCellGateBiasTensor = 14;
    static constexpr int kOutputGateBiasTensor = 15;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kProjectionWeightsTensor = 16;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kProjectionBiasTensor = 17;  // Optional

    // Incoming recurrent state tensors.
    static constexpr int kOutputStateInTensor = 18;
    static constexpr int kCellStateInTensor = 19;

    // Parameter operands (activation selector and the two clip values).
    static constexpr int kActivationParam = 20;
    static constexpr int kCellClipParam = 21;
    static constexpr int kProjClipParam = 22;

    // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kInputLayerNormWeightsTensor = 23;
    static constexpr int kForgetLayerNormWeightsTensor = 24;
    static constexpr int kCellLayerNormWeightsTensor = 25;
    static constexpr int kOutputLayerNormWeightsTensor = 26;

    // Output tensors.
    static constexpr int kScratchBufferTensor = 0;
    static constexpr int kOutputStateOutTensor = 1;
    static constexpr int kCellStateOutTensor = 2;
    static constexpr int kOutputTensor = 3;

    // Epsilon constant associated with layer normalization.
    // NOTE(review): presumably added for numerical stability — confirm at the use site.
    static constexpr float kLayerNormEpsilon = 1e-8;

    // Float32 evaluation on raw buffers. Takes the parameters, the input
    // buffer with its shape, every weight/bias/state buffer, the shapes of the
    // input-to-output and recurrent-to-output weights, and the four writable
    // output buffers. `timeMajor` and `forwardSequence` both default to true.
    // NOTE(review): pointers for optional operands are presumably allowed to
    // be null — confirm in the implementation file.
    static bool LSTMEvalFloat32(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
            bool timeMajor = true, bool forwardSequence = true);

    // Float16 counterpart of LSTMEvalFloat32: the same parameter list with
    // _Float16 buffers instead of float buffers.
    static bool LSTMEvalFloat16(
            const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
            const _Float16* input_to_input_weights_buffer,
            const _Float16* input_to_forget_weights_buffer,
            const _Float16* input_to_cell_weights_buffer,
            const _Float16* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const _Float16* recurrent_to_input_weights_buffer,
            const _Float16* recurrent_to_forget_weights_buffer,
            const _Float16* recurrent_to_cell_weights_buffer,
            const _Float16* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const _Float16* cell_to_input_weights_buffer,
            const _Float16* cell_to_forget_weights_buffer,
            const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
            const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
            const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
            const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
            const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
            const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
            const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
            const _Float16* input_layer_norm_weights_buffer,
            const _Float16* forget_layer_norm_weights_buffer,
            const _Float16* cell_layer_norm_weights_buffer,
            const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
            _Float16* cell_state_out_buffer, _Float16* output_buffer,
            _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);

    // Same float buffer parameter list as LSTMEvalFloat32, without the
    // `timeMajor` / `forwardSequence` arguments.
    // NOTE(review): presumably evaluates a single time step — confirm in the
    // implementation file.
    static bool LSTMStep(
            const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
            const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
            const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
            const Shape& input_to_output_weights_shape,
            const float* recurrent_to_input_weights_buffer,
            const float* recurrent_to_forget_weights_buffer,
            const float* recurrent_to_cell_weights_buffer,
            const float* recurrent_to_output_weights_buffer,
            const Shape& recurrent_to_output_weights_shape,
            const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
            const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
            const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
            const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
            const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
            const float* cell_bias_buffer, const float* output_gate_bias_buffer,
            const float* projection_weights_buffer, const float* projection_bias_buffer,
            const float* output_state_in_buffer, const float* cell_state_in_buffer,
            const float* input_layer_norm_weights_buffer,
            const float* forget_layer_norm_weights_buffer,
            const float* cell_layer_norm_weights_buffer,
            const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
            float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);

    // Takes every weight/bias/layer-norm operand plus the expected n_input,
    // n_output and n_cell sizes, and a writable LSTMParams.
    // NOTE(review): presumably checks operand dimensions against those sizes
    // and records feature flags in `params` — confirm in the implementation file.
    static bool CheckInputTensorDimensions(
            const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
            const RunTimeOperandInfo* input_to_forget_weights,
            const RunTimeOperandInfo* input_to_cell_weights,
            const RunTimeOperandInfo* input_to_output_weights,
            const RunTimeOperandInfo* recurrent_to_input_weights,
            const RunTimeOperandInfo* recurrent_to_forget_weights,
            const RunTimeOperandInfo* recurrent_to_cell_weights,
            const RunTimeOperandInfo* recurrent_to_output_weights,
            const RunTimeOperandInfo* cell_to_input_weights,
            const RunTimeOperandInfo* cell_to_forget_weights,
            const RunTimeOperandInfo* cell_to_output_weights,
            const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
            const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
            const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
            const RunTimeOperandInfo* input_layer_norm_weights,
            const RunTimeOperandInfo* forget_layer_norm_weights,
            const RunTimeOperandInfo* cell_layer_norm_weights,
            const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
            uint32_t n_output, uint32_t n_cell, LSTMParams* params);

   private:
    // Value-initialized so no field is indeterminate before it is assigned.
    LSTMParams params_{};

    // Read-only operand pointers; null until assigned.
    const RunTimeOperandInfo* input_ = nullptr;

    const RunTimeOperandInfo* input_to_input_weights_ = nullptr;
    const RunTimeOperandInfo* input_to_forget_weights_ = nullptr;
    const RunTimeOperandInfo* input_to_cell_weights_ = nullptr;
    const RunTimeOperandInfo* input_to_output_weights_ = nullptr;

    const RunTimeOperandInfo* recurrent_to_input_weights_ = nullptr;
    const RunTimeOperandInfo* recurrent_to_forget_weights_ = nullptr;
    const RunTimeOperandInfo* recurrent_to_cell_weights_ = nullptr;
    const RunTimeOperandInfo* recurrent_to_output_weights_ = nullptr;

    const RunTimeOperandInfo* cell_to_input_weights_ = nullptr;
    const RunTimeOperandInfo* cell_to_forget_weights_ = nullptr;
    const RunTimeOperandInfo* cell_to_output_weights_ = nullptr;

    const RunTimeOperandInfo* input_gate_bias_ = nullptr;
    const RunTimeOperandInfo* forget_gate_bias_ = nullptr;
    const RunTimeOperandInfo* cell_bias_ = nullptr;
    const RunTimeOperandInfo* output_gate_bias_ = nullptr;

    const RunTimeOperandInfo* projection_weights_ = nullptr;
    const RunTimeOperandInfo* projection_bias_ = nullptr;

    const RunTimeOperandInfo* output_state_in_ = nullptr;
    const RunTimeOperandInfo* cell_state_in_ = nullptr;

    const RunTimeOperandInfo* input_layer_norm_weights_ = nullptr;
    const RunTimeOperandInfo* forget_layer_norm_weights_ = nullptr;
    const RunTimeOperandInfo* cell_layer_norm_weights_ = nullptr;
    const RunTimeOperandInfo* output_layer_norm_weights_ = nullptr;

    // Writable operand pointers; null until assigned.
    RunTimeOperandInfo* output_state_out_ = nullptr;
    RunTimeOperandInfo* cell_state_out_ = nullptr;
    RunTimeOperandInfo* output_ = nullptr;

    RunTimeOperandInfo* scratch_buffer_ = nullptr;
};

}  // namespace nn
}  // namespace android

#endif  // FRAMEWORKS_ML_NN_LSTMCELL_H