#ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
#define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H

#include "HalOperation.h"
#include "OperationsUtils.h"

#include <vector>

namespace android {
namespace nn {

// Forward declaration is sufficient: this header only refers to
// RunTimeOperandInfo through pointers and references.
struct RunTimeOperandInfo;

// Declaration of a single step of a quantized LSTM cell.
//
// The interface suggests this lifecycle: call the static prepare() to validate
// the operation and compute output shapes, then construct an instance bound to
// the operation's operands and call eval().
// NOTE(review): confirm the exact contract against QuantizedLSTM.cpp.
class QuantizedLSTMCell {
   public:
    // Binds this cell to the operands referenced by `operation`.
    // NOTE(review): the stored pointers presumably point into `operands`, so that
    // vector must outlive this object and must not be resized/reallocated while
    // the object is in use — confirm against the constructor definition.
    QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation,
                      std::vector<RunTimeOperandInfo>& operands);

    // Presumably validates the inputs of `operation` and writes the shapes of the
    // new cell state and the output to *cellStateShape and *outputShape.
    // Returns false on failure (assumed from the bool return — confirm).
    static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation,
                        std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape,
                        Shape* outputShape);

    // Executes the cell on the bound operands, writing cellStateOut_ and output_.
    // Returns false on failure (assumed from the bool return — confirm).
    bool eval();

    // Operand positions; presumably indices into operation.inputs /
    // operation.outputs.

    // Inputs:
    static constexpr int kInputTensor = 0;
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;

    // Gate bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 9;
    static constexpr int kForgetGateBiasTensor = 10;
    static constexpr int kCellGateBiasTensor = 11;
    static constexpr int kOutputGateBiasTensor = 12;

    // State carried over from the previous time step.
    static constexpr int kPrevCellStateTensor = 13;
    static constexpr int kPrevOutputTensor = 14;

    // Outputs:
    static constexpr int kCellStateOutTensor = 0;
    static constexpr int kOutputTensor = 1;

   private:
    // Non-owning operand pointers. They are default-initialized to nullptr so an
    // instance never holds indeterminate pointers, even if the constructor does
    // not set every member. Declaration order is kept as-is because the
    // constructor's member-initializer order depends on it.
    const RunTimeOperandInfo* input_ = nullptr;

    const RunTimeOperandInfo* inputToInputWeights_ = nullptr;
    const RunTimeOperandInfo* inputToForgetWeights_ = nullptr;
    const RunTimeOperandInfo* inputToCellWeights_ = nullptr;
    const RunTimeOperandInfo* inputToOutputWeights_ = nullptr;

    const RunTimeOperandInfo* recurrentToInputWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToForgetWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToCellWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToOutputWeights_ = nullptr;

    const RunTimeOperandInfo* inputGateBias_ = nullptr;
    const RunTimeOperandInfo* forgetGateBias_ = nullptr;
    const RunTimeOperandInfo* cellGateBias_ = nullptr;
    const RunTimeOperandInfo* outputGateBias_ = nullptr;

    const RunTimeOperandInfo* prevCellState_ = nullptr;
    const RunTimeOperandInfo* prevOutput_ = nullptr;

    // Mutable because eval() produces these results.
    RunTimeOperandInfo* cellStateOut_ = nullptr;
    RunTimeOperandInfo* output_ = nullptr;

    // Presumably packs the per-gate weight tensors into the caller-provided
    // `weights` buffer laid out with dimensions `weightsDims` — confirm; the
    // caller owns the buffer and must size it accordingly.
    void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights);
    // Presumably packs the four gate bias tensors into the caller-provided
    // `bias` buffer (`outputSize` per gate) — confirm against the definition.
    void concatenateBiases(uint32_t outputSize, int32_t* bias);
};

}  // namespace nn
}  // namespace android

#endif  // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H